#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 20 11:09:03 2018

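@author: annabelle

Snapshot string-pattern analysis: for each snapshot, select the analysis window
with the largest staggered magnetization, keep the best fraction `ps` of the
snapshots, extract string-length histograms with the stringCore helpers, and
pickle the results (together with the correlation functions from `getcorr`)
for the paper figures.
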
"""
import pickle
import os
from stringCore import *
from christiescripts.stringtheory_experiment.plotcorrdope import getcorr


# Row-by-row column spans (start, stop) of each analysis-window shape, relative
# to the window's top-left corner (x0, y0). Any window name not listed here
# (including 'fixed') uses the regular 'reg' shape.
_WINDOW_SHAPES = {
    'small': [(1, 4), (0, 5), (0, 5), (0, 5), (1, 4)],
    'reg': [(2, 5), (1, 6), (0, 7), (0, 7), (0, 7), (1, 6), (2, 5)],
    'big': [(2, 6), (1, 7), (0, 8), (0, 8), (0, 8), (0, 8), (1, 7), (2, 6)],
}


def _staggeredSign(shape):
    """Return the integer checkerboard pattern (-1)**(i + j) for a 2D `shape`."""
    rows, cols = np.indices(shape)
    return 1 - 2 * ((rows + cols) % 2)


def _staggeredMagnetization(a):
    """Return sum_ij (-1)**(i + j) * a[i, j] for the 2D array `a`."""
    return np.sum(_staggeredSign(a.shape) * a)


def _windowMask(m, x0, y0, shape):
    """Return a mask like `m` that is 1 on the window `shape` placed at (x0, y0).

    Raises IndexError when a window row falls below the last row of `m`.
    Column spans running past the right edge are silently clipped by numpy slicing.
    """
    mask = np.zeros_like(m)
    for dx, (c0, c1) in enumerate(shape):
        mask[x0 + dx, y0 + c0:y0 + c1] = 1
    return mask


def _cutWithMask(m, mask):
    """Return a copy of `m` zeroed outside `mask`, or None if the mask leaves the ROI.

    The window counts as inside the ROI when sum(|m|) over the mask equals the
    number of mask sites.
    """
    m1 = m.copy()
    m1[mask == 0] = 0
    if np.sum(abs(m1)) != np.sum(mask):
        return None
    return m1


def _bestWindowPosition(m, shape):
    """Return `m` cut to the window position with the largest |staggered magnetization|.

    Ties keep the first position in scan order. Returns None if no position
    lies completely inside the ROI.
    """
    best, bestAbs = None, None
    # 5 is hardcoded in because it's the width/height of the small window
    for x0 in range(1, m.shape[0] - 5):
        for y0 in range(1, m.shape[1] - 5):
            try:
                mask = _windowMask(m, x0, y0, shape)
            except IndexError:
                continue
            m1 = _cutWithMask(m, mask)
            if m1 is None:
                continue
            smz = abs(_staggeredMagnetization(m1))
            if best is None or smz > bestAbs:
                best, bestAbs = m1, smz
    return best


def findBestWindow(amsb, window):
    """Cut an analysis window out of every snapshot and compare it to a Neel state.

    Parameters
    ----------
    amsb : iterable of 2D numpy arrays
        Snapshots. Non-zero entries are treated as ROI sites (presumably spins
        +/-1 with 0 outside the ROI -- confirm against the data).
    window : str
        'fixed' places the regular window at (2, 2). 'big', 'small' or any
        other value (regular window) scan all positions and keep the one with
        the largest absolute staggered magnetization.

    Returns
    -------
    mzs_system : list
        Staggered magnetization of the whole snapshot divided by sum(|m|).
    ms : list of 2D arrays
        Window cropped to its bounding box.
    mDiffs : list of 2D arrays
        |window - reference Neel state|; the reference sign follows the
        window's staggered magnetization, so 2 marks a site opposite to it.
    mzs_window : list
        Staggered magnetization of each cropped window.

    Snapshots whose window does not fit inside the ROI are reported and
    skipped, so all four lists stay aligned (previously a fixed-ROI failure
    aborted all remaining snapshots).
    """
    shape = _WINDOW_SHAPES.get(window, _WINDOW_SHAPES['reg'])
    mzs_system, ms, mDiffs, mzs_window = [], [], [], []

    for n, img in enumerate(amsb):
        m = np.asarray(img)

        if window == 'fixed':
            # IndexError is deliberately not caught: a fixed window must fit the image.
            m1 = _cutWithMask(m, _windowMask(m, 2, 2, shape))
            if m1 is None:
                print('Error with fixed ROI! Skipping snapshot {}.'.format(n))
                continue
        else:
            m1 = _bestWindowPosition(m, shape)
            if m1 is None:
                print('No window position fits inside the ROI! Skipping snapshot {}.'.format(n))
                continue

        mzs_system.append(_staggeredMagnetization(m) * 1.0 / np.sum(np.abs(m)))

        # cut the window out of the system (bounding box of its non-zero sites)
        rows, cols = np.nonzero(m1)
        m1 = m1[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

        # staggered magnetization in the cropped window's own coordinates
        mz = _staggeredMagnetization(m1)
        mzs_window.append(mz)

        # reference Neel state aligned with the window's magnetization, zero outside the window
        ref = (_staggeredSign(m1.shape) * np.sign(mz)).astype(m1.dtype)
        ref[m1 == 0] = 0

        ms.append(m1)
        mDiffs.append(abs(m1 - ref))

    return mzs_system, ms, mDiffs, mzs_window


def _stringLengthFCS(ms, mDiffs, numberShots, lMax=8):
    """Histogram string lengths 0..lMax over the first `numberShots` snapshots.

    In every snapshot, sites with mDiff == 2 ("red" sites) are grouped into
    areas with `findRedArea`. `possibleString` looks for and undoes string
    patterns in each area, returning the updated `m` and `mDiff`. The search
    repeats until no red sites remain. Each length count is divided by the
    number of non-zero sites of the final `m`.

    Returns (mean, standard error) arrays of length lMax + 1 over snapshots.
    """
    lengthFCS = []
    for l in range(numberShots):
        if l % 500 == 0:
            print('    ...searching shot no. ' + str(l))

        mDiff = np.copy(mDiffs[l])
        m = np.copy(ms[l])
        stringLengths = []

        # find all red sites initially
        redSites = np.transpose(np.where(mDiff == 2)).tolist()

        # group all red sites into red areas, and from these look for and undo string patterns
        while redSites:
            while redSites:
                redArea = findRedArea(redSites[0], [], redSites, mDiff, mDiff.shape[0])
                for redSite in redArea:
                    redSites.remove(redSite)

                longestLength, darkFlip, m, mDiff = possibleString(redArea, m, mDiff)
                if not darkFlip:
                    stringLengths.append(longestLength)

            # undoing strings can create new red sites: search again
            redSites = np.transpose(np.where(mDiff == 2)).tolist()

        lengthFCS.append([stringLengths.count(i) * 1.0 / np.count_nonzero(m) for i in range(lMax + 1)])

    # one row per string length, one column per snapshot
    lengthFCS = np.transpose(lengthFCS)
    lengthFCS_mean = np.array([np.mean(row) for row in lengthFCS])
    lengthFCS_std = np.array([np.std(row) * 1.0 / np.sqrt(numberShots - 1) for row in lengthFCS])
    return lengthFCS_mean, lengthFCS_std


def shotEval(dir, case, gethist, window, ps, label):
    """Evaluate string statistics and correlations for every snapshot file of one dataset.

    Parameters
    ----------
    dir : str
        Directory holding the snapshot pickles. File names are appended to it
        directly, so it must end with a path separator.
    case : str
        Dataset prefix, e.g. 'cold', 'peak0_hot' or 'D10'. A 'D' in the name
        marks a fixed-doping temperature scan; a '_' marks simulated data.
    gethist : bool
        Also pickle the string-length histogram of every file.
    window : str
        Window type passed to `findBestWindow`.
    ps : float
        Fraction of snapshots kept for the string analysis, best (largest
        |windowed staggered magnetization|) first.
    label : str
        Suffix appended to the output file names.

    Writes a '_hist.pkl' file per evaluated file (if `gethist`). When more
    than three files were evaluated, it also writes one '_avgs.pkl' file with
    all results sorted by temperature for 'D' cases and by doping otherwise.

    Raises ValueError if `ps` keeps no snapshots of a file.
    """
    fixedDoping = 'D' in case
    simulated = len(case.split('_')) > 1
    lMax = 8

    # This is a really longwinded way to get cold, hot, D: all datasets read the
    # experimental parameter file named after the last '_' token without digits.
    with open("V:\Paper Data\AFM_StringTheory\Snapshots_experiment\{}_params.pkl".format(
            ''.join(i for i in case.split('_')[-1] if not i.isdigit())), 'rb') as f:
        params = pickle.load(f)

    if fixedDoping:
        rounded_dopings = [int(case[case.index('D')+1:])]
        # simulated data use the half-filling temperatures
        tkey0 = '0' if simulated else str(rounded_dopings[0])
        rounded_temperatures = [round(T, 1) for T in params['temperature'][tkey0]]
    else:
        rounded_dopings = [int(round(d / 2.) * 2) for d in params['doping']]
        rounded_temperatures = [round(T, 1) for T in params['temperature']]

    dopings, dopingerrs = [], []
    temperatures, temperatureerrs = [], []
    asls, asles = [], []
    scs, sces, scs0, sces0 = [], [], [], []
    mzs = []
    cfds, cfs, cfes = [], [], []
    hhs, hhes = [], []
    g2ns, g2nes = [], []

    for fn in os.listdir(dir):
        if not fn.startswith(case) or fn.endswith('params.pkl'):  # filter for correct dataset
            continue
        print("Evaluating {} for {} {}".format(fn, case, label))

        # presumably '<case>_D<doping>_T<temperature>.pkl' for 'D' cases and
        # '<case>_D<doping>.pkl' otherwise (first char and 4-char extension stripped)
        if fixedDoping:
            rounddope = int(fn.split('_')[-2][1:])
            roundtemperature = float(fn.split('_')[-1][1:-4])
        else:
            rounddope = int(fn.split('_')[-1][1:-4])
            roundtemperature = round(params['temperature'][0], 1)

        dkey = str(rounddope)
        tkey = '0' if simulated else dkey
        pdopes = params['doping'][dkey] if fixedDoping else params['doping']
        pdopeerrs = params['dopingerr'][dkey] if fixedDoping else params['dopingerr']
        ptemps = params['temperature'][tkey] if fixedDoping else params['temperature']
        ptemperrs = params['temperatureerr'][tkey] if fixedDoping else params['temperatureerr']

        if simulated:
            # exact doping value; a doping scan borrows the half-filling doping error,
            # a fixed-doping temperature scan has no doping error
            dopings.append(rounddope)
            dopingerrs.append(0 if fixedDoping else pdopeerrs[rounded_dopings.index(0)])
        else:
            idx = rounded_dopings.index(rounddope)
            dopings.append(pdopes[idx])
            dopingerrs.append(pdopeerrs[idx])

        tidx = rounded_temperatures.index(roundtemperature)
        temperatures.append(ptemps[tidx])
        temperatureerrs.append(ptemperrs[tidx])
        print("  Doping      {} +/- {}".format(dopings[-1], dopingerrs[-1]))
        print("  Temperature {} +/- {}".format(temperatures[-1], temperatureerrs[-1]))

        with open(dir + fn, 'rb') as f:
            ams = pickle.load(f)

        # get spin correlations
        xs, ys, es = getcorr(ams, 'ss')
        cfs.append(ys)
        cfes.append(es)
        cfds.append(xs)

        # get moment correlations
        _, ys, es = getcorr(ams, 'hh')
        hhs.append(ys)
        hhes.append(es)
        if dopings[-1] == 0:
            g2ns.append(np.zeros_like(ys))
            g2nes.append(np.zeros_like(ys))
        else:
            _, ys, es = getcorr(ams, 'g2n', d=dopings[-1]/100.)
            g2ns.append(ys)
            g2nes.append(es)

        amsb = list(ams[1]) + list(ams[2])
        mzs_system, ms, mDiffs, mzs_window = findBestWindow(amsb, window)

        print('  Sorting snapshots!')
        indices = sorted(range(len(mzs_window)), key=lambda k: -abs(mzs_window[k]))
        mzs_window = [mzs_window[i] for i in indices]
        mDiffs = [mDiffs[i] for i in indices]
        ms = [ms[i] for i in indices]

        numberShots = int(ps * len(mzs_window) + 0.5)
        if numberShots < 1:
            raise ValueError('No snapshots kept for {} (ps={}, {} available)'.format(
                fn, ps, len(mzs_window)))
        print('    Keeping best {} snapshots of all {}'.format(numberShots, len(mDiffs)))
        print('    Lowest staggered magnetization accepted: ' + str(abs(mzs_window[numberShots - 1])))

        print('  Extracting string patterns!')
        lengthFCS_mean, lengthFCS_std = _stringLengthFCS(ms, mDiffs, numberShots, lMax)

        # get average string length
        lengths = np.arange(len(lengthFCS_mean))
        norm = np.sum(lengthFCS_mean)
        st_len = np.sum(lengthFCS_mean * lengths) / norm
        # standard error prop., assuming measured string lengths independent.
        st_len_err = np.sqrt(np.sum(((lengths - st_len) / norm) ** 2 * lengthFCS_std ** 2))
        asls.append(st_len)
        asles.append(st_len_err)
        print('    String length:  {} {}'.format(st_len, st_len_err))

        # get string count: strings of length >= 2, and of any length
        scs.append(np.sum(lengthFCS_mean[2:]))
        sces.append(np.sqrt(np.sum(lengthFCS_std[2:]**2)))
        scs0.append(np.sum(lengthFCS_mean))
        sces0.append(np.sqrt(np.sum(lengthFCS_std**2)))
        print('    String count:  {} {}'.format(scs[-1], sces[-1]))

        if gethist:
            FCSdict = {'lengthFCS_mean': lengthFCS_mean,
                       'lengthFCS_std': lengthFCS_std,
                       'doping': dopings[-1],
                       'dopingerr': dopingerrs[-1],
                       'temperature': temperatures[-1],
                       'temperatureerr': temperatureerrs[-1]}
            with open("V:\Paper Data\AFM_StringTheory\FCS\\" + fn[:-4] + label + '_hist.pkl', 'wb') as f:
                pickle.dump(FCSdict, f)

        mzs.append([mzs_system])

    # only output averages vs. doping if there are many doping or temperature values, say more than 3
    if len(dopings) > 3 or len(temperatures) > 3:
        # os.listdir order is arbitrary: sort by the scanned quantity (temperature for
        # fixed-doping 'D' cases, doping otherwise). One index permutation keeps every
        # list aligned and never compares the values themselves -- zipping and sorting
        # each list separately fell back to comparing values (or numpy arrays) on ties.
        sortKey = temperatures if fixedDoping else dopings
        order = sorted(range(len(sortKey)), key=lambda k: sortKey[k])
        print('Sorting {} datasets by {}'.format(len(order), 'temperature' if fixedDoping else 'doping'))
        (dopings, dopingerrs, temperatures, temperatureerrs, asls, asles, scs, sces,
         scs0, sces0, mzs, cfs, cfes, cfds, hhs, hhes, g2ns, g2nes) = [
            [values[k] for k in order]
            for values in (dopings, dopingerrs, temperatures, temperatureerrs, asls, asles, scs, sces,
                           scs0, sces0, mzs, cfs, cfes, cfds, hhs, hhes, g2ns, g2nes)]

        FCSdict = {'doping': np.array(dopings),
                   'dopingerr': np.array(dopingerrs),
                   'temperatures': np.array(temperatures),
                   'temperatureerrs': np.array(temperatureerrs),
                   'asl': np.array(asls),
                   'aslerr': np.array(asles),
                   'sc': np.array(scs),
                   'sces': np.array(sces),
                   'sc0': np.array(scs0),
                   'sces0': np.array(sces0),
                   'mzs': mzs,
                   'cfs': cfs,
                   'cfes': cfes,
                   'cfds': cfds,
                   'hhs': hhs,
                   'hhes': hhes,
                   'g2ns': g2ns,
                   'g2nes': g2nes}

        with open("V:\Paper Data\AFM_StringTheory\FCS\\" + case + label + '_avgs.pkl', 'wb') as f:
            pickle.dump(FCSdict, f)


def shotEvalstrlenhistonly(fn, window, ps, outfn=None):
    """Compute only the string-length histogram of one snapshot pickle and save it.

    Parameters
    ----------
    fn : str
        Snapshot pickle; snapshots are taken from entries 1 and 2 of the
        loaded object.
    window : str
        Window type passed to `findBestWindow`.
    ps : float
        Fraction of snapshots kept, best (largest |windowed staggered
        magnetization|) first.
    outfn : str, optional
        Output pickle path. Defaults to the original hard-coded QMC result
        path, whose name encodes doping 0, T=0.01 and ps=0.6 regardless of
        `ps`.

    The saved dict stores fixed metadata (doping 0, temperature 0.01).
    Raises ValueError if `ps` keeps no snapshots.
    """
    if outfn is None:
        outfn = "/volumes/Lithium/Paper Data/AFM_StringTheory/FCS/HeisenbergQMC-simulateExp/QMC_doping0_T0.01_ps0.6_nodh.pickle"
    lMax = 8

    with open(fn, 'rb') as f:
        ams = pickle.load(f)

    amsb = list(ams[1]) + list(ams[2])

    mzs_system, ms, mDiffs, mzs_window = findBestWindow(amsb, window)

    print('  Sorting snapshots!')
    indices = sorted(range(len(mzs_window)), key=lambda k: -abs(mzs_window[k]))
    mzs_window = [mzs_window[i] for i in indices]
    mDiffs = [mDiffs[i] for i in indices]
    ms = [ms[i] for i in indices]

    numberShots = int(ps * len(mzs_window) + 0.5)
    if numberShots < 1:
        raise ValueError('No snapshots kept for {} (ps={}, {} available)'.format(fn, ps, len(mzs_window)))
    print('    Keeping best {} snapshots of all {}'.format(numberShots, len(mDiffs)))
    print('    Lowest staggered magnetization accepted: ' + str(abs(mzs_window[numberShots - 1])))

    lengthFCS = []
    print('  Extracting string patterns!')
    for l in range(numberShots):
        if l % 500 == 0:
            print('    ...searching shot no. ' + str(l))

        mDiff = np.copy(mDiffs[l])
        m = np.copy(ms[l])
        stringLengths_temp = []

        # find all red sites initially
        redSites = np.transpose(np.where(mDiff == 2)).tolist()

        # group all red sites into red areas, and from these look for and undo string patterns
        while redSites:
            while redSites:
                redArea = findRedArea(redSites[0], [], redSites, mDiff, mDiff.shape[0])
                for redSite in redArea:
                    redSites.remove(redSite)

                # NOTE(review): shotEval unpacks four values from possibleString
                # (longestLength, darkFlip, m, mDiff); this five-value unpack (extra
                # `ar`) only works with a different stringCore version -- confirm.
                longestLength, ar, darkFlip, m, mDiff = possibleString(redArea, m, mDiff)
                if not darkFlip:
                    stringLengths_temp.append(longestLength)

            redSites = np.transpose(np.where(mDiff == 2)).tolist()  # find all remaining red sites and repeat

        lengthFCS.append([stringLengths_temp.count(i) * 1.0 / np.count_nonzero(m) for i in range(lMax + 1)])

    # get string length distribution: one row per length, one column per snapshot
    lengthFCS = np.transpose(lengthFCS)
    lengthFCS_mean = np.array([np.mean(row) for row in lengthFCS])
    lengthFCS_std = np.array([np.std(row) * 1.0 / np.sqrt(numberShots - 1) for row in lengthFCS])

    FCSdict = {'lengthFCS_mean': lengthFCS_mean,
               'lengthFCS_std': lengthFCS_std,
               'doping': 0,
               'dopingerr': 0,
               'temperature': 0.01,
               'temperatureerr': 0}
    with open(outfn, 'wb') as f:
        pickle.dump(FCSdict, f)


if __name__ == '__main__':
    # Evaluate every dataset used in the paper figures, one shotEval call per run.
    basePath = "V:\\Paper Data\AFM_StringTheory\\"
    expDir = 'Snapshots_experiment\\'
    sprDir = 'Snapshots_sprinkled\\'
    strDir = 'Snapshots_strings\\'

    # (subdirectory, case, gethist, window, ps, label), in execution order
    runs = []

    # Fig 2, 3A+B, 4, 5
    for sub, case in [(expDir, 'cold'), (expDir, 'hot'),
                      (sprDir, 'peak0_cold'), (sprDir, 'peak0_hot'),
                      (strDir, 'gst0.60J_cold'), (strDir, 'gst0.80J_hot')]:
        runs.append((sub, case, True, 'reg', 0.6, ''))

    # Fig 3C
    for sub, case in [(expDir, 'D10'), (sprDir, 'peak0_D10')]:
        runs.append((sub, case, False, 'reg', 0.6, ''))

    # Fig S2: vary window shape and postselection fraction
    for sub, case in [(expDir, 'cold'), (sprDir, 'peak0_cold'), (strDir, 'gst0.60J_cold')]:
        for win, frac in [('small', 0.6), ('big', 0.6), ('fixed', 0.6), ('reg', 0.4), ('reg', 0.8)]:
            runs.append((sub, case, False, win, frac, '_{}_ps{}'.format(win, frac)))

    # Fig S6
    for sub in (expDir, sprDir):
        prefix = '' if sub == expDir else 'peak0_'
        for dope in ('D0', 'D6', 'D12'):
            runs.append((sub, prefix + dope, False, 'reg', 0.6, ''))

    # Fig S7
    for case in ['gst0.65J_cold', 'gst0.55J_cold',
                 'pR0.8_cold', 'pR0.5_cold',
                 'peak1_cold', 'peak2_cold', 'peak3_cold', 'peakinf_cold']:
        runs.append((strDir, case, False, 'reg', 0.6, ''))

    for sub, case, gethist, win, frac, tag in runs:
        shotEval(basePath + sub, case, gethist, win, frac, label=tag)

